Day 29: today we build a zebrafish behavior-analysis system based on artificial intelligence and deep learning. The full implementation follows:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import cv2
import numpy as np
from yolov8 import YOLOv8
from sort import Sort
# Custom CNN feature extractor
class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # With 64x64 inputs, three 2x2 poolings leave an 8x8 feature map
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 512)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 256 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
# Multi-layer LSTM model
class MultiLayerLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super(MultiLayerLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize hidden and cell states on the same device as the input
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])   # classify from the last time step
        return out
# Select GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the YOLOv8 detector
yolo_model = YOLOv8("yolov8_weights.pth")

# Initialize the SORT tracker
sort_tracker = Sort()

# Initialize the CNN feature extractor
cnn_feature_extractor = CNNFeatureExtractor().to(device)

# LSTM model parameters
input_size = 512    # size of the CNN feature vector
hidden_size = 1024
output_size = 5     # number of behavior classes
num_layers = 3

# Initialize the LSTM model and load pretrained weights
lstm_model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers).to(device)
lstm_model.load_state_dict(torch.load("lstm_model.pth", map_location=device))
lstm_model.eval()

# Loss function and optimizer (used during training)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(cnn_feature_extractor.parameters()) + list(lstm_model.parameters()), lr=0.0001)

# Open the video file
cap = cv2.VideoCapture("zebrafish_video.mp4")
def extract_cnn_features(bbox, frame):
    """
    Extract CNN features from a bounding box on the current frame.
    :param bbox: bounding box (x1, y1, x2, y2)
    :param frame: current video frame
    :return: extracted CNN feature tensor
    """
    x1, y1, x2, y2 = [int(v) for v in bbox]   # tracker outputs may be floats
    roi = frame[y1:y2, x1:x2]
    roi_resized = cv2.resize(roi, (64, 64))
    # Data augmentation: random horizontal flip and small rotation
    if np.random.rand() > 0.5:
        roi_resized = cv2.flip(roi_resized, 1)
    angle = np.random.uniform(-10, 10)
    M = cv2.getRotationMatrix2D((32, 32), angle, 1.0)
    roi_resized = cv2.warpAffine(roi_resized, M, (64, 64))
    # Scale pixel values to [0, 1] and reorder to (batch, channel, H, W)
    roi_tensor = torch.tensor(roi_resized / 255.0, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        cnn_features = cnn_feature_extractor(roi_tensor)
    return cnn_features
def visualize_results(frame, bbox, behavior_class):
    """
    Overlay the behavior classification result on the video frame.
    :param frame: current video frame
    :param bbox: bounding box (x1, y1, x2, y2)
    :param behavior_class: predicted behavior label
    """
    x1, y1, x2, y2 = [int(v) for v in bbox]
    color = (0, 255, 0)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    label = f"Behavior: {behavior_class}"
    cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    # Detect zebrafish
    detections = yolo_model.detect(frame)
    # Track zebrafish; SORT returns rows of [x1, y1, x2, y2, track_id]
    tracks = sort_tracker.update(detections)
    sequence_features = []
    for track in tracks:
        x1, y1, x2, y2, fish_id = track
        cnn_features = extract_cnn_features((x1, y1, x2, y2), frame)
        sequence_features.append(cnn_features.squeeze(0).cpu().numpy())
    if len(sequence_features) > 0:
        # Note: this stacks the features of all fish in the current frame as one sequence
        sequence_features_tensor = torch.tensor(np.stack(sequence_features), dtype=torch.float32).unsqueeze(0).to(device)
        # LSTM prediction
        with torch.no_grad():
            behavior_output = lstm_model(sequence_features_tensor)
        # Take the highest-scoring behavior class
        _, behavior_class = torch.max(behavior_output, 1)
        behavior_class = behavior_class.item()
        # Visualize the results
        for track in tracks:
            visualize_results(frame, track[:4], behavior_class)
    cv2.imshow("Advanced Zebrafish Behavior Analysis", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()
class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 512)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 256 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
Explanation: the three convolution layers widen the channel dimension (3 → 64 → 128 → 256) while each pooling halves the spatial resolution, so a 64x64 crop ends up as an 8x8x256 feature map. The forward method defines the forward pass: data flows through the convolution and pooling layers and finally through the fully connected layers, producing a 512-dimensional feature vector.
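As a quick sanity check (not part of the original pipeline), a dummy 64x64 crop should come out as a 512-dimensional feature vector:

    x = torch.randn(1, 3, 64, 64)        # one 64x64 RGB crop
    features = CNNFeatureExtractor()(x)
    print(features.shape)                # torch.Size([1, 512])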
class MultiLayerLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super(MultiLayerLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out
Explanation: input_size is the number of features per time step (the size of the CNN feature vector), hidden_size is the size of the LSTM hidden state, and output_size is the number of behavior classes. The forward method runs the forward pass: the initial hidden and cell states are zero tensors created on the same device as the input, the sequence is run through the LSTM, and the hidden state of the last time step is passed through a fully connected layer to produce the classification logits.
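A short smoke test (the 30-step sequence length here is just an illustrative value):

    model = MultiLayerLSTM(input_size=512, hidden_size=1024, output_size=5)
    seq = torch.randn(2, 30, 512)   # batch of 2 sequences, 30 steps of 512-dim features
    logits = model(seq)
    print(logits.shape)             # torch.Size([2, 5])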
# Select GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the YOLOv8 detector
yolo_model = YOLOv8("yolov8_weights.pth")

# Initialize the SORT tracker
sort_tracker = Sort()

# Initialize the CNN feature extractor
cnn_feature_extractor = CNNFeatureExtractor().to(device)

# LSTM model parameters
input_size = 512    # size of the CNN feature vector
hidden_size = 1024
output_size = 5     # number of behavior classes
num_layers = 3

# Initialize the LSTM model and load pretrained weights
lstm_model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers).to(device)
lstm_model.load_state_dict(torch.load("lstm_model.pth", map_location=device))
lstm_model.eval()

# Loss function and optimizer (only needed when training)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(cnn_feature_extractor.parameters()) + list(lstm_model.parameters()), lr=0.0001)

# Open the video file
cap = cv2.VideoCapture("zebrafish_video.mp4")
Explanation: the device is selected first so that every model and tensor lands on the GPU when one is available. criterion is the cross-entropy loss function and optimizer is the Adam optimizer; both are used when training the models. cap reads the video file.
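The inference loop below never touches criterion or optimizer, so for completeness here is a minimal, hypothetical training step. train_step, clips, and labels are assumptions for illustration; clips is a batch of frame sequences shaped (batch, time, 3, 64, 64):

    def train_step(clips, labels):
        # Both models must be in training mode (the script above calls .eval())
        cnn_feature_extractor.train()
        lstm_model.train()
        batch, time = clips.shape[:2]
        # Fold the time axis into the batch so the CNN sees individual crops
        feats = cnn_feature_extractor(clips.reshape(batch * time, 3, 64, 64))
        feats = feats.reshape(batch, time, -1)   # back to (batch, time, 512)
        logits = lstm_model(feats)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()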
def extract_cnn_features(bbox, frame):
    x1, y1, x2, y2 = [int(v) for v in bbox]   # tracker outputs may be floats
    roi = frame[y1:y2, x1:x2]
    roi_resized = cv2.resize(roi, (64, 64))
    # Data augmentation: random horizontal flip and small rotation
    if np.random.rand() > 0.5:
        roi_resized = cv2.flip(roi_resized, 1)
    angle = np.random.uniform(-10, 10)
    M = cv2.getRotationMatrix2D((32, 32), angle, 1.0)
    roi_resized = cv2.warpAffine(roi_resized, M, (64, 64))
    # Scale pixel values to [0, 1] and reorder to (batch, channel, H, W)
    roi_tensor = torch.tensor(roi_resized / 255.0, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        cnn_features = cnn_feature_extractor(roi_tensor)
    return cnn_features
Explanation: roi_resized resizes the ROI to 64x64 pixels before the random flip and rotation are applied. Note that this augmentation runs at inference time as well; in practice you would usually enable it only during training.
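A quick sanity check on a synthetic frame (the frame size and bounding box values here are arbitrary):

    dummy_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    features = extract_cnn_features((100, 100, 164, 164), dummy_frame)
    print(features.shape)   # torch.Size([1, 512])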
def visualize_results(frame, bbox, behavior_class):
    x1, y1, x2, y2 = [int(v) for v in bbox]
    color = (0, 255, 0)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    label = f"Behavior: {behavior_class}"
    cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    # Detect zebrafish
    detections = yolo_model.detect(frame)
    # Track zebrafish; SORT returns rows of [x1, y1, x2, y2, track_id]
    tracks = sort_tracker.update(detections)
    sequence_features = []
    for track in tracks:
        x1, y1, x2, y2, fish_id = track
        cnn_features = extract_cnn_features((x1, y1, x2, y2), frame)
        sequence_features.append(cnn_features.squeeze(0).cpu().numpy())
    if len(sequence_features) > 0:
        sequence_features_tensor = torch.tensor(np.stack(sequence_features), dtype=torch.float32).unsqueeze(0).to(device)
        # LSTM prediction
        with torch.no_grad():
            behavior_output = lstm_model(sequence_features_tensor)
        # Take the highest-scoring behavior class
        _, behavior_class = torch.max(behavior_output, 1)
        behavior_class = behavior_class.item()
        # Visualize the results
        for track in tracks:
            visualize_results(frame, track[:4], behavior_class)
    cv2.imshow("Advanced Zebrafish Behavior Analysis", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()
Explanation: visualize_results overlays the behavior classification result on the video frame, drawing each zebrafish's bounding box and predicted behavior label. The while cap.isOpened() loop is the main loop that processes every frame of the video; pressing the q key exits. Together, the code combines object detection, tracking, feature extraction, behavior classification, and result visualization into a complete zebrafish behavior-analysis system. Its depth comes from the custom CNN, the multi-layer LSTM that captures temporal dependencies, the data augmentation, and GPU acceleration.
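One caveat: as written, the loop stacks the features of every fish detected in a single frame into one "sequence", so the LSTM actually sees a cross-fish sequence rather than a temporal one. For the LSTM to capture per-fish temporal dependencies, the features of each track would need to be buffered across frames. A minimal sketch of that idea (SEQ_LEN and track_buffers are assumptions, not part of the original code):

    from collections import defaultdict, deque

    SEQ_LEN = 30  # hypothetical window length in frames
    track_buffers = defaultdict(lambda: deque(maxlen=SEQ_LEN))

    # Inside the main loop, after extracting features for each track:
    #     track_buffers[fish_id].append(cnn_features.squeeze(0).cpu().numpy())
    #     if len(track_buffers[fish_id]) == SEQ_LEN:
    #         seq = torch.tensor(np.stack(track_buffers[fish_id]),
    #                            dtype=torch.float32).unsqueeze(0).to(device)
    #         with torch.no_grad():
    #             behavior_class = lstm_model(seq).argmax(dim=1).item()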